{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# A Scientific Deep Dive Into SageMaker LDA\n",
    "\n",
    "1. [Introduction](#Introduction)\n",
    "1. [Setup](#Setup)\n",
    "1. [Data Exploration](#DataExploration)\n",
    "1. [Training](#Training)\n",
    "1. [Inference](#Inference)\n",
    "1. [Epilogue](#Epilogue)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Introduction\n",
    "***\n",
    "\n",
    "Amazon SageMaker LDA is an unsupervised learning algorithm that attempts to describe a set of observations as a mixture of distinct categories. Latent Dirichlet Allocation (LDA) is most commonly used to discover a user-specified number of topics shared by documents within a text corpus. Here each observation is a document, the features are the presence (or occurrence count) of each word, and the categories are the topics. Since the method is unsupervised, the topics are not specified up front, and are not guaranteed to align with how a human may naturally categorize documents. The topics are learned as a probability distribution over the words that occur in each document. Each document, in turn, is described as a mixture of topics.\n",
    "\n",
    "This notebook is similar to **LDA-Introduction.ipynb** but its objective and scope are a different. We will be taking a deeper dive into the theory. The primary goals of this notebook are,\n",
    "\n",
    "* to understand the LDA model and the example dataset,\n",
    "* understand how the Amazon SageMaker LDA algorithm works,\n",
    "* interpret the meaning of the inference output.\n",
    "\n",
    "Former knowledge of LDA is not required. However, we will run through concepts rather quickly and at least a foundational knowledge of mathematics or machine learning is recommended. Suggested references are provided, as appropriate."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "!conda install -y scipy"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "%matplotlib inline\n",
    "\n",
    "import os, re, tarfile\n",
    "\n",
    "import boto3\n",
    "import matplotlib.pyplot as plt\n",
    "import mxnet as mx\n",
    "import numpy as np\n",
    "np.set_printoptions(precision=3, suppress=True)\n",
    "\n",
    "# some helpful utility functions are defined in the Python module\n",
    "# \"generate_example_data\" located in the same directory as this\n",
    "# notebook\n",
    "from generate_example_data import (\n",
    "    generate_griffiths_data, match_estimated_topics,\n",
    "    plot_lda, plot_lda_topics)\n",
    "\n",
    "# accessing the SageMaker Python SDK\n",
    "import sagemaker\n",
    "from sagemaker.amazon.common import numpy_to_record_serializer\n",
    "from sagemaker.predictor import csv_serializer, json_deserializer"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Setup\n",
    "\n",
    "***\n",
    "\n",
    "*This notebook was created and tested on an ml.m4.xlarge notebook instance.*\n",
    "\n",
    "We first need to specify some AWS credentials; specifically data locations and access roles. This is the only cell of this notebook that you will need to edit. In particular, we need the following data:\n",
    "\n",
    "* `bucket` - An S3 bucket accessible by this account.\n",
    "  * Used to store input training data and model data output.\n",
    "  * Should be withing the same region as this notebook instance, training, and hosting.\n",
    "* `prefix` - The location in the bucket where this notebook's input and and output data will be stored. (The default value is sufficient.)\n",
    "* `role` - The IAM Role ARN used to give training and hosting access to your data.\n",
    "  * See documentation on how to create these.\n",
    "  * The script below will try to determine an appropriate Role ARN."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true,
    "isConfigCell": true
   },
   "outputs": [],
   "source": [
    "from sagemaker import get_execution_role\n",
    "\n",
    "role = get_execution_role()\n",
    "\n",
    "bucket = '<your_s3_bucket_name_here>'\n",
    "prefix = 'sagemaker/DEMO-lda-science'\n",
    "\n",
    "\n",
    "print('Training input/output will be stored in {}/{}'.format(bucket, prefix))\n",
    "print('\\nIAM Role: {}'.format(role))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## The LDA Model\n",
    "\n",
    "As mentioned above, LDA is a model for discovering latent topics describing a collection of documents. In this section we will give a brief introduction to the model. Let,\n",
    "\n",
    "* $M$ = the number of *documents* in a corpus\n",
    "* $N$ = the average *length* of a document.\n",
    "* $V$ = the size of the *vocabulary* (the total number of unique words)\n",
    "\n",
    "We denote a *document* by a vector $w \\in \\mathbb{R}^V$ where $w_i$ equals the number of times the $i$th word in the vocabulary occurs within the document. This is called the \"bag-of-words\" format of representing a document.\n",
    "\n",
    "$$\n",
    "\\underbrace{w}_{\\text{document}} = \\overbrace{\\big[ w_1, w_2, \\ldots, w_V \\big] }^{\\text{word counts}},\n",
    "\\quad\n",
    "V = \\text{vocabulary size}\n",
    "$$\n",
    "\n",
    "The *length* of a document is equal to the total number of words in the document: $N_w = \\sum_{i=1}^V w_i$.\n",
    "\n",
    "An LDA model is defined by two parameters: a topic-word distribution matrix $\\beta \\in \\mathbb{R}^{K \\times V}$ and a  Dirichlet topic prior $\\alpha \\in \\mathbb{R}^K$. In particular, let,\n",
    "\n",
    "$$\\beta = \\left[ \\beta_1, \\ldots, \\beta_K \\right]$$\n",
    "\n",
    "be a collection of $K$ *topics* where each topic $\\beta_k \\in \\mathbb{R}^V$ is represented as probability distribution over the vocabulary. One of the utilities of the LDA model is that a given word is allowed to appear in multiple topics with positive probability. The Dirichlet topic prior is a vector $\\alpha \\in \\mathbb{R}^K$ such that $\\alpha_k > 0$ for all $k$."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Data Exploration\n",
    "\n",
    "---\n",
    "\n",
    "## An Example Dataset\n",
    "\n",
    "Before explaining further let's get our hands dirty with an example dataset. The following synthetic data comes from [1] and comes with a very useful visual interpretation.\n",
    "\n",
 [1] Thomas Griffiths">
    "> [1] Thomas Griffiths and Mark Steyvers. *Finding Scientific Topics.* Proceedings of the National Academy of Sciences, 101(suppl 1):5228-5235, 2004."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "print('Generating example data...')\n",
    "num_documents = 6000\n",
    "known_alpha, known_beta, documents, topic_mixtures = generate_griffiths_data(\n",
    "    num_documents=num_documents, num_topics=10)\n",
    "num_topics, vocabulary_size = known_beta.shape\n",
    "\n",
    "\n",
    "# separate the generated data into training and tests subsets\n",
    "num_documents_training = int(0.9*num_documents)\n",
    "num_documents_test = num_documents - num_documents_training\n",
    "\n",
    "documents_training = documents[:num_documents_training]\n",
    "documents_test = documents[num_documents_training:]\n",
    "\n",
    "topic_mixtures_training = topic_mixtures[:num_documents_training]\n",
    "topic_mixtures_test = topic_mixtures[num_documents_training:]\n",
    "\n",
    "print('documents_training.shape = {}'.format(documents_training.shape))\n",
    "print('documents_test.shape = {}'.format(documents_test.shape))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Let's start by taking a closer look at the documents. Note that the vocabulary size of these data is $V = 25$. The average length of each document in this data set is 150. (See `generate_griffiths_data.py`.)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "print('First training document =\\n{}'.format(documents_training[0]))\n",
    "print('\\nVocabulary size = {}'.format(vocabulary_size))\n",
    "print('Length of first document = {}'.format(documents_training[0].sum()))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "average_document_length = documents.sum(axis=1).mean()\n",
    "print('Observed average document length = {}'.format(average_document_length))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The example data set above also returns the LDA parameters,\n",
    "\n",
    "$$(\\alpha, \\beta)$$\n",
    "\n",
    "used to generate the documents. Let's examine the first topic and verify that it is a probability distribution on the vocabulary."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "print('First topic =\\n{}'.format(known_beta[0]))\n",
    "\n",
    "print('\\nTopic-word probability matrix (beta) shape: (num_topics, vocabulary_size) = {}'.format(known_beta.shape))\n",
    "print('\\nSum of elements of first topic = {}'.format(known_beta[0].sum()))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Unlike some clustering algorithms, one of the versatilities of the LDA model is that a given word can belong to multiple topics. The probability of that word occurring in each topic may differ, as well. This is reflective of real-world data where, for example, the word *\"rover\"* appears in a *\"dogs\"* topic as well as in a *\"space exploration\"* topic.\n",
    "\n",
    "In our synthetic example dataset, the first word in the vocabulary belongs to both Topic #1 and Topic #6 with non-zero probability."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "print('Topic #1:\\n{}'.format(known_beta[0]))\n",
    "print('Topic #6:\\n{}'.format(known_beta[5]))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Human beings are visual creatures, so it might be helpful to come up with a visual representation of these documents.\n",
    "\n",
    "In the below plots, each pixel of a document represents a word. The greyscale intensity is a measure of how frequently that word occurs within the document. Below we plot the first few documents of the training set reshaped into 5x5 pixel grids."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "%matplotlib inline\n",
    "\n",
    "fig = plot_lda(documents_training, nrows=3, ncols=4, cmap='gray_r', with_colorbar=True)\n",
    "fig.suptitle('$w$ - Document Word Counts')\n",
    "fig.set_dpi(160)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "When taking a close look at these documents we can see some patterns in the word distributions suggesting that, perhaps, each topic represents a \"column\" or \"row\" of words with non-zero probability and that each document is composed primarily of a handful of topics.\n",
    "\n",
    "Below we plots the *known* topic-word probability distributions, $\\beta$. Similar to the documents we reshape each probability distribution to a $5 \\times 5$ pixel image where the color represents the probability of that each word occurring in the topic."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "%matplotlib inline\n",
    "\n",
    "fig = plot_lda(known_beta, nrows=1, ncols=10)\n",
    "fig.suptitle(r'Known $\\beta$ - Topic-Word Probability Distributions')\n",
    "fig.set_dpi(160)\n",
    "fig.set_figheight(2)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "These 10 topics were used to generate the document corpus. Next, we will learn about how this is done."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Generating Documents\n",
    "\n",
    "LDA is a generative model, meaning that the LDA parameters $(\\alpha, \\beta)$ are used to construct documents word-by-word by drawing from the topic-word distributions. In fact, looking closely at the example documents above you can see that some documents sample more words from some topics than from others.\n",
    "\n",
    "LDA works as follows: given \n",
    "\n",
    "* $M$ documents $w^{(1)}, w^{(2)}, \\ldots, w^{(M)}$,\n",
    "* an average document length of $N$,\n",
    "* and an LDA model $(\\alpha, \\beta)$.\n",
    "\n",
    "**For** each document, $w^{(m)}$:\n",
    "* sample a topic mixture: $\\theta^{(m)} \\sim \\text{Dirichlet}(\\alpha)$\n",
    "* **For** each word $n$ in the document:\n",
    "  * Sample a topic $z_n^{(m)} \\sim \\text{Multinomial}\\big( \\theta^{(m)} \\big)$\n",
    "  * Sample a word from this topic, $w_n^{(m)} \\sim \\text{Multinomial}\\big( \\beta_{z_n^{(m)}} \\; \\big)$\n",
    "  * Add to document\n",
    "\n",
    "The [plate notation](https://en.wikipedia.org/wiki/Plate_notation) for the LDA model, introduced in [2], encapsulates this process pictorially.\n",
    "\n",
    "![](http://scikit-learn.org/stable/_images/lda_model_graph.png)\n",
    "\n",
    "> [2] David M Blei, Andrew Y Ng, and Michael I Jordan. Latent Dirichlet Allocation. Journal of Machine Learning Research, 3(Jan):993–1022, 2003."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Topic Mixtures\n",
    "\n",
    "For the documents we generated above lets look at their corresponding topic mixtures, $\\theta \\in \\mathbb{R}^K$. The topic mixtures represent the probablility that a given word of the document is sampled from a particular topic. For example, if the topic mixture of an input document $w$ is,\n",
    "\n",
    "$$\\theta = \\left[ 0.3, 0.2, 0, 0.5, 0, \\ldots, 0 \\right]$$\n",
    "\n",
    "then $w$ is 30% generated from the first topic, 20% from the second topic, and 50% from the fourth topic. In particular, the words contained in the document are sampled from the first topic-word probability distribution 30% of the time, from the second distribution 20% of the time, and the fourth disribution 50% of the time.\n",
    "\n",
    "\n",
    "The objective of inference, also known as scoring, is to determine the most likely topic mixture of a given input document. Colloquially, this means figuring out which topics appear within a given document and at what ratios. We will perform infernece later in the [Inference](#Inference) section.\n",
    "\n",
    "Since we generated these example documents using the LDA model we know the topic mixture generating them. Let's examine these topic mixtures."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "print('First training document =\\n{}'.format(documents_training[0]))\n",
    "print('\\nVocabulary size = {}'.format(vocabulary_size))\n",
    "print('Length of first document = {}'.format(documents_training[0].sum()))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "print('First training document topic mixture =\\n{}'.format(topic_mixtures_training[0]))\n",
    "print('\\nNumber of topics = {}'.format(num_topics))\n",
    "print('sum(theta) = {}'.format(topic_mixtures_training[0].sum()))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We plot the first document along with its topic mixture. We also plot the topic-word probability distributions again for reference."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "%matplotlib inline\n",
    "\n",
    "fig, (ax1,ax2) = plt.subplots(2, 1)\n",
    "\n",
    "ax1.matshow(documents[0].reshape(5,5), cmap='gray_r')\n",
    "ax1.set_title(r'$w$ - Document', fontsize=20)\n",
    "ax1.set_xticks([])\n",
    "ax1.set_yticks([])\n",
    "\n",
    "cax2 = ax2.matshow(topic_mixtures[0].reshape(1,-1), cmap='Reds', vmin=0, vmax=1)\n",
    "cbar = fig.colorbar(cax2, orientation='horizontal')\n",
    "ax2.set_title(r'$\\theta$ - Topic Mixture', fontsize=20)\n",
    "ax2.set_xticks([])\n",
    "ax2.set_yticks([])\n",
    "\n",
    "fig.set_dpi(100)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "%matplotlib inline\n",
    "\n",
    "# pot\n",
    "fig = plot_lda(known_beta, nrows=1, ncols=10)\n",
    "fig.suptitle(r'Known $\\beta$ - Topic-Word Probability Distributions')\n",
    "fig.set_dpi(160)\n",
    "fig.set_figheight(1.5)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Finally, let's plot several documents with their corresponding topic mixtures. We can see how topics with large weight in the document lead to more words in the document within the corresponding \"row\" or \"column\"."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "%matplotlib inline\n",
    "\n",
    "fig = plot_lda_topics(documents_training, 3, 4, topic_mixtures=topic_mixtures)\n",
    "fig.suptitle(r'$(w,\\theta)$ - Documents with Known Topic Mixtures')\n",
    "fig.set_dpi(160)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Training\n",
    "\n",
    "***\n",
    "\n",
    "In this section we will give some insight into how AWS SageMaker LDA fits an LDA model to a corpus, create an run a SageMaker LDA training job, and examine the output trained model."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Topic Estimation using Tensor Decompositions\n",
    "\n",
    "Given a document corpus, Amazon SageMaker LDA uses a spectral tensor decomposition technique to determine the LDA model $(\\alpha, \\beta)$ which most likely describes the corpus. See [1] for a primary reference of the theory behind the algorithm. The spectral decomposition, itself, is computed using the CPDecomp algorithm described in [2].\n",
    "\n",
    "The overall idea is the following: given a corpus of documents $\\mathcal{W} = \\{w^{(1)}, \\ldots, w^{(M)}\\}, \\; w^{(m)} \\in \\mathbb{R}^V,$ we construct a statistic tensor,\n",
    "\n",
    "$$T \\in \\bigotimes^3 \\mathbb{R}^V$$\n",
    "\n",
    "such that the spectral decomposition of the tensor is approximately the LDA parameters $\\alpha \\in \\mathbb{R}^K$ and $\\beta \\in \\mathbb{R}^{K \\times V}$ which maximize the likelihood of observing the corpus for a given number of topics, $K$,\n",
    "\n",
    "$$T \\approx \\sum_{k=1}^K \\alpha_k \\; (\\beta_k \\otimes \\beta_k \\otimes \\beta_k)$$\n",
    "\n",
    "This statistic tensor encapsulates information from the corpus such as the document mean, cross correlation, and higher order statistics. For details, see [1].\n",
    "\n",
    "\n",
    "> [1] Animashree Anandkumar, Rong Ge, Daniel Hsu, Sham Kakade, and Matus Telgarsky. *\"Tensor Decompositions for Learning Latent Variable Models\"*, Journal of Machine Learning Research, 15:2773–2832, 2014.\n",
    ">\n",
    "> [2] Tamara Kolda and Brett Bader. *\"Tensor Decompositions and Applications\"*. SIAM Review, 51(3):455–500, 2009.\n",
    "\n",
    "\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Store Data on S3\n",
    "\n",
    "Before we run training we need to prepare the data.\n",
    "\n",
    "A SageMaker training job needs access to training data stored in an S3 bucket. Although training can accept data of various formats we convert the documents MXNet RecordIO Protobuf format before uploading to the S3 bucket defined at the beginning of this notebook."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# convert documents_training to Protobuf RecordIO format\n",
    "recordio_protobuf_serializer = numpy_to_record_serializer()\n",
    "fbuffer = recordio_protobuf_serializer(documents_training)\n",
    "\n",
    "# upload to S3 in bucket/prefix/train\n",
    "fname = 'lda.data'\n",
    "s3_object = os.path.join(prefix, 'train', fname)\n",
    "boto3.Session().resource('s3').Bucket(bucket).Object(s3_object).upload_fileobj(fbuffer)\n",
    "\n",
    "s3_train_data = 's3://{}/{}'.format(bucket, s3_object)\n",
    "print('Uploaded data to S3: {}'.format(s3_train_data))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Next, we specify a Docker container containing the SageMaker LDA algorithm. For your convenience, a region-specific container is automatically chosen for you to minimize cross-region data communication"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "from sagemaker.amazon.amazon_estimator import get_image_uri\n",
    "\n",
    "region_name = boto3.Session().region_name\n",
    "container = get_image_uri(boto3.Session().region_name, 'lda')\n",
    "\n",
    "print('Using SageMaker LDA container: {} ({})'.format(container, region_name))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Training Parameters\n",
    "\n",
    "Particular to a SageMaker LDA training job are the following hyperparameters:\n",
    "\n",
    "* **`num_topics`** - The number of topics or categories in the LDA model.\n",
    "  * Usually, this is not known a priori.\n",
    "  * In this example, howevever, we know that the data is generated by five topics.\n",
    "\n",
    "* **`feature_dim`** - The size of the *\"vocabulary\"*, in LDA parlance.\n",
    "  * In this example, this is equal 25.\n",
    "\n",
    "* **`mini_batch_size`** - The number of input training documents.\n",
    "\n",
    "* **`alpha0`** - *(optional)* a measurement of how \"mixed\" are the topic-mixtures.\n",
    "  * When `alpha0` is small the data tends to be represented by one or few topics.\n",
    "  * When `alpha0` is large the data tends to be an even combination of several or many topics.\n",
    "  * The default value is `alpha0 = 1.0`.\n",
    "\n",
    "In addition to these LDA model hyperparameters, we provide additional parameters defining things like the EC2 instance type on which training will run, the S3 bucket containing the data, and the AWS access role. Note that,\n",
    "\n",
    "* Recommended instance type: `ml.c4`\n",
    "* Current limitations:\n",
    "  * SageMaker LDA *training* can only run on a single instance.\n",
    "  * SageMaker LDA does not take advantage of GPU hardware.\n",
    "  * (The Amazon AI Algorithms team is working hard to provide these capabilities in a future release!)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Using the above configuration create a SageMaker client and use the client to create a training job."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "session = sagemaker.Session()\n",
    "\n",
    "# specify general training job information\n",
    "lda = sagemaker.estimator.Estimator(\n",
    "    container,\n",
    "    role,\n",
    "    output_path='s3://{}/{}/output'.format(bucket, prefix),\n",
    "    train_instance_count=1,\n",
    "    train_instance_type='ml.c4.2xlarge',\n",
    "    sagemaker_session=session,\n",
    ")\n",
    "\n",
    "# set algorithm-specific hyperparameters\n",
    "lda.set_hyperparameters(\n",
    "    num_topics=num_topics,\n",
    "    feature_dim=vocabulary_size,\n",
    "    mini_batch_size=num_documents_training,\n",
    "    alpha0=1.0,\n",
    ")\n",
    "\n",
    "# run the training job on input data stored in S3\n",
    "lda.fit({'train': s3_train_data})"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "If you see the message\n",
    "\n",
    "> `===== Job Complete =====`\n",
    "\n",
    "at the bottom of the output logs then that means training sucessfully completed and the output LDA model was stored in the specified output path. You can also view information about and the status of a training job using the AWS SageMaker console. Just click on the \"Jobs\" tab and select training job matching the training job name, below:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "print('Training job name: {}'.format(lda.latest_training_job.job_name))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Inspecting the Trained Model\n",
    "\n",
    "We know the LDA parameters $(\\alpha, \\beta)$ used to generate the example data. How does the learned model compare the known one? In this section we will download the model data and measure how well SageMaker LDA did in learning the model.\n",
    "\n",
    "First, we download the model data. SageMaker will output the model in \n",
    "\n",
    "> `s3://<bucket>/<prefix>/output/<training job name>/output/model.tar.gz`.\n",
    "\n",
    "SageMaker LDA stores the model as a two-tuple $(\\alpha, \\beta)$ where each LDA parameter is an MXNet NDArray."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# download and extract the model file from S3\n",
    "job_name = lda.latest_training_job.job_name\n",
    "model_fname = 'model.tar.gz'\n",
    "model_object = os.path.join(prefix, 'output', job_name, 'output', model_fname)\n",
    "boto3.Session().resource('s3').Bucket(bucket).Object(model_object).download_file(fname)\n",
    "with tarfile.open(fname) as tar:\n",
    "    tar.extractall()\n",
    "print('Downloaded and extracted model tarball: {}'.format(model_object))\n",
    "\n",
    "# obtain the model file\n",
    "model_list = [fname for fname in os.listdir('.') if fname.startswith('model_')]\n",
    "model_fname = model_list[0]\n",
    "print('Found model file: {}'.format(model_fname))\n",
    "\n",
    "# get the model from the model file and store in Numpy arrays\n",
    "alpha, beta = mx.ndarray.load(model_fname)\n",
    "learned_alpha_permuted = alpha.asnumpy()\n",
    "learned_beta_permuted = beta.asnumpy()\n",
    "\n",
    "print('\\nLearned alpha.shape = {}'.format(learned_alpha_permuted.shape))\n",
    "print('Learned beta.shape = {}'.format(learned_beta_permuted.shape))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Presumably, SageMaker LDA has found the topics most likely used to generate the training corpus. However, even if this is case the topics would not be returned in any particular order. Therefore, we match the found topics to the known topics closest in L1-norm in order to find the topic permutation.\n",
    "\n",
    "Note that we will use the `permutation` later during inference to match known topic mixtures to found topic mixtures.\n",
    "\n",
    "Below plot the known topic-word probability distribution, $\\beta \\in \\mathbb{R}^{K \\times V}$ next to the distributions found by SageMaker LDA as well as the L1-norm errors between the two."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "permutation, learned_beta = match_estimated_topics(known_beta, learned_beta_permuted)\n",
    "learned_alpha = learned_alpha_permuted[permutation]\n",
    "\n",
    "fig = plot_lda(np.vstack([known_beta, learned_beta]), 2, 10)\n",
    "fig.set_dpi(160)\n",
    "fig.suptitle('Known vs. Found Topic-Word Probability Distributions')\n",
    "fig.set_figheight(3)\n",
    "\n",
    "beta_error = np.linalg.norm(known_beta - learned_beta, 1)\n",
    "alpha_error = np.linalg.norm(known_alpha - learned_alpha, 1)\n",
    "print('L1-error (beta) = {}'.format(beta_error))\n",
    "print('L1-error (alpha) = {}'.format(alpha_error))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Not bad!\n",
    "\n",
    "In the eyeball-norm the topics match quite well. In fact, the topic-word distribution error is approximately 2%."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Inference\n",
    "\n",
    "***\n",
    "\n",
    "A trained model does nothing on its own. We now want to use the model we computed to perform inference on data. For this example, that means predicting the topic mixture representing a given document.\n",
    "\n",
    "We create an inference endpoint using the SageMaker Python SDK `deploy()` function from the job we defined above. We specify the instance type where inference is computed as well as an initial number of instances to spin up."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "lda_inference = lda.deploy(\n",
    "    initial_instance_count=1,\n",
    "    instance_type='ml.m4.xlarge',  # LDA inference may work better at scale on ml.c4 instances\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Congratulations! You now have a functioning SageMaker LDA inference endpoint. You can confirm the endpoint configuration and status by navigating to the \"Endpoints\" tab in the AWS SageMaker console and selecting the endpoint matching the endpoint name, below: "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "print('Endpoint name: {}'.format(lda_inference.endpoint))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "With this realtime endpoint at our fingertips we can finally perform inference on our training and test data.\n",
    "\n",
    "We can pass a variety of data formats to our inference endpoint. In this example we will demonstrate passing CSV-formatted data. Other available formats are JSON-formatted, JSON-sparse-formatter, and RecordIO Protobuf. We make use of the SageMaker Python SDK utilities `csv_serializer` and `json_deserializer` when configuring the inference endpoint."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "lda_inference.content_type = 'text/csv'\n",
    "lda_inference.serializer = csv_serializer\n",
    "lda_inference.deserializer = json_deserializer"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We pass some test documents to the inference endpoint. Note that the serializer and deserializer will atuomatically take care of the datatype conversion."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "results = lda_inference.predict(documents_test[:12])\n",
    "\n",
    "print(results)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "It may be hard to see but the output format of SageMaker LDA inference endpoint is a Python dictionary with the following format.\n",
    "\n",
    "```\n",
    "{\n",
    "  'predictions': [\n",
    "    {'topic_mixture': [ ... ] },\n",
    "    {'topic_mixture': [ ... ] },\n",
    "    {'topic_mixture': [ ... ] },\n",
    "    ...\n",
    "  ]\n",
    "}\n",
    "```\n",
    "\n",
    "We extract the topic mixtures, themselves, corresponding to each of the input documents."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "inferred_topic_mixtures_permuted = np.array([prediction['topic_mixture'] for prediction in results['predictions']])\n",
    "\n",
    "print('Inferred topic mixtures (permuted):\\n\\n{}'.format(inferred_topic_mixtures_permuted))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Inference Analysis\n",
    "\n",
    "Recall that although SageMaker LDA successfully learned the underlying topics which generated the sample data the topics were in a different order. Before we compare to known topic mixtures $\\theta \\in \\mathbb{R}^K$ we should also permute the inferred topic mixtures\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "inferred_topic_mixtures = inferred_topic_mixtures_permuted[:,permutation]\n",
    "\n",
    "print('Inferred topic mixtures:\\n\\n{}'.format(inferred_topic_mixtures))"
   ]
  },
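  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To make the column indexing above concrete, here is a small worked example with $K = 3$. It assumes `permutation` is the integer index array obtained earlier when matching learned topics to known topics; the symbols $\\pi$ (for `permutation`), $\\tilde{\\theta}$ (a row of the endpoint output) and $\\hat{\\theta}$ (the reordered row) are notation introduced for this note, and the numbers are made up for illustration. NumPy indexing with `[:, permutation]` gives\n",
    "\n",
    "$$\n",
    "\\hat{\\theta}_k = \\tilde{\\theta}_{\\pi(k)}, \\qquad k = 0, \\ldots, K-1.\n",
    "$$\n",
    "\n",
    "For instance, with $\\pi = (2, 0, 1)$,\n",
    "\n",
    "$$\n",
    "\\tilde{\\theta} = [0.1, 0.2, 0.7] \\quad \\longmapsto \\quad \\hat{\\theta} = [0.7, 0.1, 0.2].\n",
    "$$\n",
    "\n",
    "The entries are only reordered, so the values in each row are unchanged."
   ]
  },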
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Let's plot these topic mixture probability distributions alongside the known ones."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "%matplotlib inline\n",
    "\n",
    "# create array of bar plots\n",
    "width = 0.4\n",
    "x = np.arange(10)\n",
    "\n",
    "nrows, ncols = 3, 4\n",
    "fig, ax = plt.subplots(nrows, ncols, sharey=True)\n",
    "for i in range(nrows):\n",
    "    for j in range(ncols):\n",
    "        index = i*ncols + j\n",
    "        ax[i,j].bar(x, topic_mixtures_test[index], width, color='C0')\n",
    "        ax[i,j].bar(x+width, inferred_topic_mixtures[index], width, color='C1')\n",
    "        ax[i,j].set_xticks(range(num_topics))\n",
    "        ax[i,j].set_yticks(np.linspace(0,1,5))\n",
    "        ax[i,j].grid(which='major', axis='y')\n",
    "        ax[i,j].set_ylim([0,1])\n",
    "        ax[i,j].set_xticklabels([])\n",
    "        if (i==(nrows-1)):\n",
    "            ax[i,j].set_xticklabels(range(num_topics), fontsize=7)\n",
    "        if (j==0):\n",
    "            ax[i,j].set_yticklabels([0,'',0.5,'',1.0], fontsize=7)\n",
    "        \n",
    "fig.suptitle('Known vs. Inferred Topic Mixtures')\n",
    "ax_super = fig.add_subplot(111, frameon=False)\n",
    "ax_super.tick_params(labelcolor='none', top='off', bottom='off', left='off', right='off')\n",
    "ax_super.grid(False)\n",
    "ax_super.set_xlabel('Topic Index')\n",
    "ax_super.set_ylabel('Topic Probability')\n",
    "fig.set_dpi(160)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "In the eyeball-norm these look quite comparable.\n",
    "\n",
    "Let's be more scientific about this. Below we compute and plot the distribution of L1-errors from **all** of the test documents. Note that we send a new payload of test documents to the inference endpoint and apply the appropriate permutation to the output."
   ]
  },
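  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For reference, the L1-error of a single document is the sum of absolute differences between its inferred and known topic mixtures. The symbol $\\hat{\\theta}$ for the inferred mixture is notation introduced for this note, and the numbers in the example are illustrative only.\n",
    "\n",
    "$$\n",
    "\\lVert \\hat{\\theta} - \\theta \\rVert_1 = \\sum_{k=0}^{K-1} \\left| \\hat{\\theta}_k - \\theta_k \\right|\n",
    "$$\n",
    "\n",
    "When both vectors are probability distributions, this error lies between 0 (identical mixtures) and 2 (mixtures that share no topic). For example, with $K = 3$,\n",
    "\n",
    "$$\n",
    "\\theta = [0.5, 0.5, 0], \\quad \\hat{\\theta} = [0.4, 0.5, 0.1] \\quad \\Longrightarrow \\quad \\lVert \\hat{\\theta} - \\theta \\rVert_1 = 0.1 + 0 + 0.1 = 0.2.\n",
    "$$"
   ]
  },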
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "%%time\n",
    "\n",
    "# create a payload containing all of the test documents and run inference again\n",
    "#\n",
    "# TRY THIS:\n",
    "#   try switching between the test data set and a subset of the training\n",
    "#   data set. It is likely that LDA inference will perform better against\n",
    "#   the training set than the holdout test set.\n",
    "#\n",
    "payload_documents = documents_test                    # Example 1\n",
    "known_topic_mixtures = topic_mixtures_test            # Example 1\n",
    "#payload_documents = documents_training[:600];         # Example 2\n",
    "#known_topic_mixtures = topic_mixtures_training[:600]  # Example 2\n",
    "\n",
    "print('Invoking endpoint...\\n')\n",
    "results = lda_inference.predict(payload_documents)\n",
    "\n",
    "inferred_topic_mixtures_permuted = np.array([prediction['topic_mixture'] for prediction in results['predictions']])\n",
    "inferred_topic_mixtures = inferred_topic_mixtures_permuted[:,permutation]\n",
    "\n",
    "print('known_topics_mixtures.shape = {}'.format(known_topic_mixtures.shape))\n",
    "print('inferred_topics_mixtures_test.shape = {}\\n'.format(inferred_topic_mixtures.shape))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "%matplotlib inline\n",
    "\n",
    "l1_errors = np.linalg.norm((inferred_topic_mixtures - known_topic_mixtures), 1, axis=1)\n",
    "\n",
    "# plot the error freqency\n",
    "fig, ax_frequency = plt.subplots()\n",
    "bins = np.linspace(0,1,40)\n",
    "weights = np.ones_like(l1_errors)/len(l1_errors)\n",
    "freq, bins, _ = ax_frequency.hist(l1_errors, bins=50, weights=weights, color='C0')\n",
    "ax_frequency.set_xlabel('L1-Error')\n",
    "ax_frequency.set_ylabel('Frequency', color='C0')\n",
    "\n",
    "\n",
    "# plot the cumulative error\n",
    "shift = (bins[1]-bins[0])/2\n",
    "x = bins[1:] - shift\n",
    "ax_cumulative = ax_frequency.twinx()\n",
    "cumulative = np.cumsum(freq)/sum(freq)\n",
    "ax_cumulative.plot(x, cumulative, marker='o', color='C1')\n",
    "ax_cumulative.set_ylabel('Cumulative Frequency', color='C1')\n",
    "\n",
    "\n",
    "# align grids and show\n",
    "freq_ticks = np.linspace(0, 1.5*freq.max(), 5)\n",
    "freq_ticklabels = np.round(100*freq_ticks)/100\n",
    "ax_frequency.set_yticks(freq_ticks)\n",
    "ax_frequency.set_yticklabels(freq_ticklabels)\n",
    "ax_cumulative.set_yticks(np.linspace(0, 1, 5))\n",
    "ax_cumulative.grid(which='major', axis='y')\n",
    "ax_cumulative.set_ylim((0,1))\n",
    "\n",
    "\n",
    "fig.suptitle('Topic Mixutre L1-Errors')\n",
    "fig.set_dpi(110)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Machine learning algorithms are not perfect and the data above suggests this is true of SageMaker LDA. With more documents and some hyperparameter tuning we can obtain more accurate results against the known topic-mixtures.\n",
    "\n",
    "For now, let's just investigate the documents-topic mixture pairs that seem to do well as well as those that do not. Below we retreive a document and topic mixture corresponding to a small L1-error as well as one with a large L1-error."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "N = 6\n",
    "\n",
    "good_idx = (l1_errors < 0.05)\n",
    "good_documents = payload_documents[good_idx][:N]\n",
    "good_topic_mixtures = inferred_topic_mixtures[good_idx][:N]\n",
    "\n",
    "poor_idx = (l1_errors > 0.3)\n",
    "poor_documents = payload_documents[poor_idx][:N]\n",
    "poor_topic_mixtures = inferred_topic_mixtures[poor_idx][:N]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "%matplotlib inline\n",
    "\n",
    "fig = plot_lda_topics(good_documents, 2, 3, topic_mixtures=good_topic_mixtures)\n",
    "fig.suptitle('Documents With Accurate Inferred Topic-Mixtures')\n",
    "fig.set_dpi(120)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "%matplotlib inline\n",
    "\n",
    "fig = plot_lda_topics(poor_documents, 2, 3, topic_mixtures=poor_topic_mixtures)\n",
    "fig.suptitle('Documents With Inaccurate Inferred Topic-Mixtures')\n",
    "fig.set_dpi(120)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "In this example set the documents on which inference was not as accurate tend to have a denser topic-mixture. This makes sense when extrapolated to real-world datasets: it can be difficult to nail down which topics are represented in a document when the document uses words from a large subset of the vocabulary."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Stop / Close the Endpoint\n",
    "\n",
    "Finally, we should delete the endpoint before we close the notebook.\n",
    "\n",
    "To do so execute the cell below. Alternately, you can navigate to the \"Endpoints\" tab in the SageMaker console, select the endpoint with the name stored in the variable `endpoint_name`, and select \"Delete\" from the \"Actions\" dropdown menu. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "sagemaker.Session().delete_endpoint(lda_inference.endpoint)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Epilogue\n",
    "\n",
    "---\n",
    "\n",
    "In this notebook we,\n",
    "\n",
    "* learned about the LDA model,\n",
    "* generated some example LDA documents and their corresponding topic-mixtures,\n",
    "* trained a SageMaker LDA model on a training set of documents and compared the learned model to the known model,\n",
    "* created an inference endpoint,\n",
    "* used the endpoint to infer the topic mixtures of a test input and analyzed the inference error.\n",
    "\n",
    "There are several things to keep in mind when applying SageMaker LDA to real-word data such as a corpus of text documents. Note that input documents to the algorithm, both in training and inference, need to be vectors of integers representing word counts. Each index corresponds to a word in the corpus vocabulary. Therefore, one will need to \"tokenize\" their corpus vocabulary.\n",
    "\n",
    "$$\n",
    "\\text{\"cat\"} \\mapsto 0, \\; \\text{\"dog\"} \\mapsto 1 \\; \\text{\"bird\"} \\mapsto 2, \\ldots\n",
    "$$\n",
    "\n",
    "Each text document then needs to be converted to a \"bag-of-words\" format document.\n",
    "\n",
    "$$\n",
    "w = \\text{\"cat bird bird bird cat\"} \\quad \\longmapsto \\quad w = [2, 0, 3, 0, \\ldots, 0]\n",
    "$$\n",
    "\n",
    "Also note that many real-word applications have large vocabulary sizes. It may be necessary to represent the input documents in sparse format. Finally, the use of stemming and lemmatization in data preprocessing provides several benefits. Doing so can improve training and inference compute time since it reduces the effective vocabulary size. More importantly, though, it can improve the quality of learned topic-word probability matrices and inferred topic mixtures. For example, the words *\"parliament\"*, *\"parliaments\"*, *\"parliamentary\"*, *\"parliament's\"*, and *\"parliamentarians\"* are all essentially the same word, *\"parliament\"*, but with different conjugations. For the purposes of detecting topics, such as a *\"politics\"* or *governments\"* topic, the inclusion of all five does not add much additional value as they all essentiall describe the same feature."
   ]
  },
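  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The tokenization and bag-of-words steps described above can be sketched in plain Python. This is a minimal illustration and not part of the SageMaker SDK: the helper names `build_vocabulary` and `to_bag_of_words` are hypothetical, whitespace splitting stands in for a real tokenizer, and words are indexed in order of first appearance.\n",
    "\n",
    "```python\n",
    "from collections import Counter\n",
    "\n",
    "\n",
    "def build_vocabulary(corpus):\n",
    "    # Map each distinct word to an integer index, in order of first appearance.\n",
    "    vocabulary = {}\n",
    "    for text in corpus:\n",
    "        for word in text.lower().split():\n",
    "            vocabulary.setdefault(word, len(vocabulary))\n",
    "    return vocabulary\n",
    "\n",
    "\n",
    "def to_bag_of_words(text, vocabulary):\n",
    "    # Return a dense list of word counts; out-of-vocabulary words are ignored.\n",
    "    counts = Counter(text.lower().split())\n",
    "    vector = [0] * len(vocabulary)\n",
    "    for word, count in counts.items():\n",
    "        if word in vocabulary:\n",
    "            vector[vocabulary[word]] = count\n",
    "    return vector\n",
    "\n",
    "\n",
    "corpus = ['cat dog bird', 'cat bird bird bird cat']\n",
    "vocabulary = build_vocabulary(corpus)\n",
    "document = to_bag_of_words(corpus[1], vocabulary)\n",
    "print(vocabulary)\n",
    "print(document)\n",
    "```\n",
    "\n",
    "With this toy corpus the vocabulary has only three words, so the second document becomes a length-3 count vector. For a realistic vocabulary most entries would be zero, which is why a sparse representation may be preferable."
   ]
  },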
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "conda_mxnet_p36",
   "language": "python",
   "name": "conda_mxnet_p36"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.3"
  },
  "notice": "Copyright 2017 Amazon.com, Inc. or its affiliates. All Rights Reserved.  Licensed under the Apache License, Version 2.0 (the \"License\"). You may not use this file except in compliance with the License. A copy of the License is located at http://aws.amazon.com/apache2.0/ or in the \"license\" file accompanying this file. This file is distributed on an \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License."
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
